package com.masonluo.mlonlinejudge.acl.core;

import com.masonluo.mlonlinejudge.acl.annotations.Filter;
import com.masonluo.mlonlinejudge.acl.annotations.RequireAuthentication;
import com.masonluo.mlonlinejudge.acl.annotations.RequireRoles;
import com.masonluo.mlonlinejudge.acl.core.filter.AuthenticateFilter;
import com.masonluo.mlonlinejudge.acl.utils.ObjectUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.annotation.AnnotatedGenericBeanDefinition;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanNameGenerator;
import org.springframework.beans.factory.support.DefaultBeanNameGenerator;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Spring {@link BeanPostProcessor} that assembles the ACL infrastructure from the bean factory.
 *
 * <p>During {@link #init()} it:
 * <ul>
 *   <li>collects every bean annotated with {@link Filter} (each must implement
 *       {@link AuthenticateFilter}) and appends it to an {@link AuthenticateFilterChain};</li>
 *   <li>resolves the single global {@link AccessControlFailReturn}, falling back to a new
 *       {@code DefaultAccessControlFailReturn} instance when none is declared;</li>
 *   <li>resolves the single {@link AuthenticationIdentificationAchiever}, registering a
 *       {@code DefaultAuthenticationMetadataAchiever} bean definition when none is declared.</li>
 * </ul>
 *
 * <p>After each bean is initialized, beans carrying {@link RequireRoles} or
 * {@link RequireAuthentication} (as detected by {@code ObjectUtils.hasAnnotationIncludingMethod})
 * are passed to {@code wrapBeanForAccessControl}, which currently returns the bean unchanged.
 *
 * @author masonluo
 * @date 2021/4/17 3:20 PM
 */
@Component
public class FilterBeanPostProcessor implements BeanPostProcessor, BeanFactoryAware {

    private static final Logger log = LoggerFactory.getLogger(FilterBeanPostProcessor.class);

    /** {@link Filter}-annotated filters, keyed by the name produced by {@link #generator}. */
    private final Map<String, AuthenticateFilter> processors = new ConcurrentHashMap<>(64);

    /** Set through {@link #setBeanFactory(BeanFactory)}; every loader requires it to be non-null. */
    private ConfigurableListableBeanFactory beanFactory;

    /** Derives the key under which a filter is stored in {@link #processors}. */
    private final NameGenerator generator = new FilterNameGenerator();

    /** Names the default achiever bean definition when one has to be registered. */
    private final BeanNameGenerator beanNameGenerator = new DefaultBeanNameGenerator();

    /** Filters in discovery order. */
    private final AuthenticateFilterChain chain = new AuthenticateFilterChain();

    private AccessControlFailReturn failReturn;

    private AuthenticationIdentificationAchiever authenticationMetadataAchiever;

    public FilterBeanPostProcessor() {
    }

    /**
     * Loads the filters, the global fail return and the authentication identification achiever.
     *
     * <p>NOTE(review): the loaders call {@code beanFactory.getBean(...)} from a BeanPostProcessor's
     * initialization, which instantiates those beans before all post-processors are registered;
     * confirm they do not need to be processed by other post-processors.
     *
     * @throws IllegalArgumentException if the bean factory is missing, a {@link Filter}-annotated bean
     *         is not an {@link AuthenticateFilter}, more than one {@link AccessControlFailReturn} or
     *         {@link AuthenticationIdentificationAchiever} is declared, or a default achiever must be
     *         registered but the bean factory is not a {@link BeanDefinitionRegistry}
     */
    @PostConstruct
    public void init() {
        loadFilter();
        loadGlobalAccessControlFailReturn();
        loadAuthenticationMetadataAchiever();
    }

    /**
     * Fails fast when no bean factory has been injected yet.
     *
     * @throws IllegalArgumentException if {@link #beanFactory} is null
     */
    private void requireBeanFactory() {
        if (beanFactory == null) {
            String msg = "BeanFactory should not be null, please check";
            log.error(msg);
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Resolves the single {@link AuthenticationIdentificationAchiever}.
     *
     * <p>If none is declared (including ancestor factories), a default one is registered.
     *
     * @throws IllegalArgumentException if more than one achiever bean exists
     */
    private void loadAuthenticationMetadataAchiever() {
        requireBeanFactory();
        String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
                beanFactory, AuthenticationIdentificationAchiever.class);
        if (beanNames.length > 1) {
            // Only a single instance may be registered in the container.
            throw new IllegalArgumentException("Only one " + AuthenticationIdentificationAchiever.class.getName()
                    + " could be registered in spring container, but found: [" + String.join(", ", beanNames) + "]");
        }
        String beanName = ObjectUtils.isEmpty(beanNames)
                ? registerDefaultAuthenticationMetadataAchiever()
                : beanNames[0];
        authenticationMetadataAchiever = beanFactory.getBean(beanName, AuthenticationIdentificationAchiever.class);
    }

    /**
     * Collects every {@link Filter}-annotated bean into {@link #processors} and appends it to {@link #chain}.
     *
     * <p>Filters are appended in the order returned by
     * {@code BeanFactoryUtils.beanNamesForAnnotationIncludingAncestors}.
     *
     * @throws IllegalArgumentException if the bean factory is missing or an annotated bean is not an
     *         {@link AuthenticateFilter}
     */
    public void loadFilter() {
        requireBeanFactory();
        String[] beanNames = BeanFactoryUtils.beanNamesForAnnotationIncludingAncestors(beanFactory, Filter.class);
        if (ObjectUtils.isEmpty(beanNames)) {
            log.info("Doesn't find any filter for acl");
            return;
        }
        for (String beanName : beanNames) {
            Object bean = beanFactory.getBean(beanName);
            if (!(bean instanceof AuthenticateFilter)) {
                String actualType = bean == null ? "null" : bean.getClass().getName();
                throw new IllegalArgumentException("The @" + Filter.class.getSimpleName()
                        + " annotation should only annotate " + AuthenticateFilter.class.getName()
                        + " implementations, but bean '" + beanName + "' is of type " + actualType);
            }
            AuthenticateFilter filter = (AuthenticateFilter) bean;
            String name = generator.getName(filter.getClass());
            AuthenticateFilter previous = processors.put(name, filter);
            if (previous != null) {
                // The map keeps only the latest filter per name while the chain keeps both;
                // surface the divergence instead of overwriting silently.
                log.warn("Filter name '{}' is shared by {} and {}; the name map keeps the latter, "
                                + "but both are added to the chain",
                        name, previous.getClass().getName(), filter.getClass().getName());
            }
            chain.addLast(filter);
        }
    }

    /**
     * Resolves the single global {@link AccessControlFailReturn}.
     *
     * <p>If none is declared, a new {@code DefaultAccessControlFailReturn} is used. Unlike the default
     * achiever, it is not registered as a bean.
     *
     * @throws IllegalArgumentException if the bean factory is missing or more than one fail return exists
     */
    public void loadGlobalAccessControlFailReturn() {
        requireBeanFactory();
        String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(beanFactory, AccessControlFailReturn.class);
        if (ObjectUtils.isEmpty(beanNames)) {
            // Fall back to the built-in global fail return.
            failReturn = new DefaultAccessControlFailReturn();
            return;
        }
        if (beanNames.length > 1) {
            // Only one global fail return is allowed.
            throw new IllegalArgumentException("Find more than one class [" + AccessControlFailReturn.class.getName() + "], please confirm");
        }
        failReturn = beanFactory.getBean(beanNames[0], AccessControlFailReturn.class);
    }

    /**
     * Registers a bean definition for {@code DefaultAuthenticationMetadataAchiever}.
     *
     * @return the generated bean name under which the definition was registered
     * @throws IllegalArgumentException if the bean factory is not a {@link BeanDefinitionRegistry}
     */
    private String registerDefaultAuthenticationMetadataAchiever() {
        if (!(beanFactory instanceof BeanDefinitionRegistry)) {
            throw new IllegalArgumentException("Can't acquire the bean definition");
        }
        BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
        BeanDefinition definition = new AnnotatedGenericBeanDefinition(DefaultAuthenticationMetadataAchiever.class);
        String beanName = beanNameGenerator.generateBeanName(definition, registry);
        registry.registerBeanDefinition(beanName, definition);
        return beanName;
    }

    /**
     * Passes beans that carry {@link RequireRoles} or {@link RequireAuthentication} (per
     * {@code ObjectUtils.hasAnnotationIncludingMethod}) to {@code wrapBeanForAccessControl}.
     *
     * <p>All other beans are returned untouched.
     */
    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        if (!(ObjectUtils.hasAnnotationIncludingMethod(bean, RequireRoles.class)
                || ObjectUtils.hasAnnotationIncludingMethod(bean, RequireAuthentication.class))) {
            return bean;
        }
        return wrapBeanForAccessControl(bean);
    }

    /**
     * Placeholder for access-control wrapping.
     *
     * <p>TODO: it currently returns {@code bean} unchanged, so the annotations are not enforced here.
     */
    private Object wrapBeanForAccessControl(Object bean) {
        return bean;
    }

    /**
     * Stores the bean factory used by the loaders.
     *
     * <p>NOTE(review): {@code @RequireRoles} on this infrastructure callback looks like a leftover;
     * confirm it is intentional before relying on it.
     *
     * @throws IllegalArgumentException if {@code beanFactory} is not a {@link ConfigurableListableBeanFactory}
     */
    @RequireRoles
    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        if (!(beanFactory instanceof ConfigurableListableBeanFactory)) {
            throw new IllegalArgumentException(
                    "FilterBeanPostProcessor requires a ConfigurableListableBeanFactory: " + beanFactory);
        }
        this.beanFactory = (ConfigurableListableBeanFactory) beanFactory;
    }

}
